{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "cellView": "both",
    "colab": {},
    "colab_type": "code",
    "deletable": true,
    "editable": true,
    "id": "qnMpW5Y9nv2l"
   },
   "outputs": [],
   "source": [
    "# Copyright 2020 Google LLC\n",
    "#\n",
    "# Licensed under the Apache License, Version 2.0 (the \"License\");\n",
    "# you may not use this file except in compliance with the License.\n",
    "# You may obtain a copy of the License at\n",
    "#\n",
    "#     https://www.apache.org/licenses/LICENSE-2.0\n",
    "#\n",
    "# Unless required by applicable law or agreed to in writing, software\n",
    "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
    "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
    "# See the License for the specific language governing permissions and\n",
    "# limitations under the License."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "deletable": true,
    "editable": true,
    "id": "mHF9VCProKJN"
   },
   "source": [
    "# AI Explanations: Deploying an image model\n",
    "\n",
    "<table align=\"left\">\n",
    "  <td>\n",
    "    <a href=\"https://colab.research.google.com/github/GoogleCloudPlatform/ml-on-gcp/blob/master/tutorials/explanations/ai-explanations-image.ipynb\">\n",
    "      <img src=\"https://cloud.google.com/ml-engine/images/colab-logo-32px.png\" alt=\"Colab logo\"> Run in Colab\n",
    "    </a>\n",
    "  </td>\n",
    "  <td>\n",
    "    <a href=\"https://github.com/GoogleCloudPlatform/ml-on-gcp/tree/master/tutorials/explanations/ai-explanations-image.ipynb\">\n",
    "      <img src=\"https://cloud.google.com/ml-engine/images/github-logo-32px.png\" alt=\"GitHub logo\">\n",
    "      View on GitHub\n",
    "    </a>\n",
    "  </td>\n",
    "</table>"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "deletable": true,
    "editable": true,
    "id": "hZzRVxNtH-zG"
   },
   "source": [
    "## Overview\n",
    "\n",
    "This tutorial shows how to train a Keras classification model on image data and deploy it to the AI Platform Explanations service to get feature attributions on your deployed model.\n",
    "\n",
    "If you've already got a trained model and want to deploy it to AI Explanations, skip to the **Export the model as a TF 1 SavedModel** section."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "deletable": true,
    "editable": true,
    "id": "iN69d4D9Flrh"
   },
   "source": [
    "### Dataset\n",
    "\n",
    "The dataset used for this tutorial is the [flowers dataset](https://www.tensorflow.org/datasets/catalog/tf_flowers) from [TensorFlow Datasets](https://www.tensorflow.org/datasets/catalog/overview)."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "deletable": true,
    "editable": true,
    "id": "Su2qu-4CW-YH"
   },
   "source": [
    "### Objective\n",
    "\n",
    "The goal of this tutorial is to train a model on a simple image dataset (flower classification) to understand how you can use AI Explanations with image models. For image models, AI Explanations returns an image with the pixels highlighted that signaled your model's prediction the most.\n",
    "\n",
    "This tutorial focuses more on deploying the model to AI Platform with Explanations than on the design of the model itself. "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "deletable": true,
    "editable": true,
    "id": "912RD_3fxGeH"
   },
   "source": [
    "### Costs\n",
    "\n",
    "This tutorial uses billable components of Google Cloud Platform (GCP):\n",
    "\n",
    "* AI Platform for:\n",
    "  * Prediction\n",
    "  * Explanation: AI Explanations comes at no extra charge to prediction prices. However, explanation requests take longer to process than normal predictions, so heavy usage of Explanations along with auto-scaling may result in more nodes being started and thus more charges\n",
    "* Cloud Storage for:\n",
    "  * Storing model files for deploying to Cloud AI Platform\n",
    "\n",
    "Learn about [AI Platform\n",
    "pricing](https://cloud.google.com/ml-engine/docs/pricing) and [Cloud Storage\n",
    "pricing](https://cloud.google.com/storage/pricing), and use the [Pricing\n",
    "Calculator](https://cloud.google.com/products/calculator/)\n",
    "to generate a cost estimate based on your projected usage."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "deletable": true,
    "editable": true,
    "id": "rgLXkyHEvTVD"
   },
   "source": [
    "## Before you begin\n",
    "\n",
    "**Make sure you're running this notebook in a GPU runtime if you have that option. In Colab, select Runtime --> Change runtime type**\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "deletable": true,
    "editable": true,
    "id": "avDUUQEGTnUo"
   },
   "source": [
    "This tutorial assumes you are running the notebook either in **Colab** or **Cloud AI Platform Notebooks**."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "deletable": true,
    "editable": true,
    "id": "i2qsxysTVc-l"
   },
   "source": [
    "### Set up your GCP project\n",
    "\n",
    "**The following steps are required, regardless of your notebook environment.**\n",
    "\n",
    "1. [Select or create a GCP project.](https://console.cloud.google.com/cloud-resource-manager)\n",
    "\n",
    "2. [Make sure that billing is enabled for your project.](https://cloud.google.com/billing/docs/how-to/modify-project)\n",
    "\n",
    "3. [Enable the AI Platform Training & Prediction and Compute Engine APIs.](https://console.cloud.google.com/flows/enableapi?apiid=ml.googleapis.com,compute_component)\n",
    "\n",
    "4. Enter your project ID in the cell below. Then run the  cell to make sure the\n",
    "Cloud SDK uses the right project for all the commands in this notebook.\n",
    "\n",
    "**Note**: Jupyter runs lines prefixed with `!` as shell commands, and it interpolates Python variables prefixed with `$` into these commands."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "cellView": "both",
    "colab": {},
    "colab_type": "code",
    "deletable": true,
    "editable": true,
    "id": "4qxwBA4RM9Lu"
   },
   "outputs": [],
   "source": [
    "PROJECT_ID=\"[your-project-id]\""
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "deletable": true,
    "editable": true,
    "id": "TSy-f05IO4LB"
   },
   "source": [
    "### Authenticate your GCP account\n",
    "\n",
    "**If you are using AI Platform Notebooks**, your environment is already\n",
    "authenticated. Skip this step."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "deletable": true,
    "editable": true,
    "id": "fZQUrHdXNJnk"
   },
   "source": [
    "**If you are using Colab**, run the cell below and follow the instructions\n",
    "when prompted to authenticate your account via oAuth."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "deletable": true,
    "editable": true,
    "id": "W9i6oektpgld",
    "tags": [
     "no_execute"
    ]
   },
   "outputs": [],
   "source": [
    "import sys, os\n",
    "import warnings\n",
    "import googleapiclient\n",
    "\n",
    "warnings.filterwarnings('ignore')\n",
    "os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3' \n",
    "# If you are running this notebook in Colab, follow the\n",
    "# instructions to authenticate your GCP account. This provides access to your\n",
    "# Cloud Storage bucket and lets you submit training jobs and prediction\n",
    "# requests.\n",
    "\n",
    "if 'google.colab' in sys.modules:\n",
    "  from google.colab import auth as google_auth\n",
    "  google_auth.authenticate_user()\n",
    "  !gcloud config set project $PROJECT_ID\n",
    "  try:\n",
    "    %tensorflow_version 1.x\n",
    "  except Exception:\n",
    "    pass\n",
    "  import tensorflow as tf"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "deletable": true,
    "editable": true,
    "id": "tT061irlJwkg"
   },
   "source": [
    "### Create a Cloud Storage bucket\n",
    "\n",
    "**The following steps are required, regardless of your notebook environment.**\n",
    "\n",
    "When you submit a training job using the Cloud SDK, you upload a Python package\n",
    "containing your training code to a Cloud Storage bucket. AI Platform runs\n",
    "the code from this package. In this tutorial, AI Platform also saves the\n",
    "trained model that results from your job in the same bucket. You can then\n",
    "create an AI Platform model version based on this output in order to serve\n",
    "online predictions.\n",
    "\n",
    "Set the name of your Cloud Storage bucket below. It must be unique across all\n",
    "Cloud Storage buckets. \n",
    "\n",
    "You may also change the `REGION` variable, which is used for operations\n",
    "throughout the rest of this notebook. Make sure to [choose a region where Cloud\n",
    "AI Platform services are\n",
    "available](https://cloud.google.com/ml-engine/docs/tensorflow/regions). You may\n",
    "not use a Multi-Regional Storage bucket for training with AI Platform."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "deletable": true,
    "editable": true,
    "id": "bTxmbDg1I0x1"
   },
   "outputs": [],
   "source": [
    "BUCKET_NAME = PROJECT_ID + \"_flowers_model\"\n",
    "REGION = \"us-central1\""
   ]
  },
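  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "Because `BUCKET_NAME` is built from `PROJECT_ID`, leaving the `[your-project-id]` placeholder in place produces a name that `gsutil mb` will reject. The following is a minimal sketch of a local sanity check. It encodes only a simplified subset of the Cloud Storage bucket naming rules (names without dots), and `looks_like_valid_bucket_name` is a hypothetical helper, not part of any Google Cloud library.\n",
    "\n",
    "```python\n",
    "import re\n",
    "\n",
    "# Hypothetical helper covering a simplified subset of the Cloud Storage\n",
    "# bucket naming rules (names without dots only); not an exhaustive validator.\n",
    "BUCKET_NAME_PATTERN = re.compile('^[a-z0-9][a-z0-9_-]{1,61}[a-z0-9]$')\n",
    "\n",
    "\n",
    "def looks_like_valid_bucket_name(name):\n",
    "  # 3-63 characters: lowercase letters, digits, dashes, or underscores,\n",
    "  # starting and ending with a letter or digit.\n",
    "  if not BUCKET_NAME_PATTERN.match(name):\n",
    "    return False\n",
    "  # Bucket names may not start with the reserved 'goog' prefix.\n",
    "  return not name.startswith('goog')\n",
    "\n",
    "\n",
    "print(looks_like_valid_bucket_name('my-project-123_flowers_model'))  # True\n",
    "print(looks_like_valid_bucket_name('[your-project-id]_flowers_model'))  # False\n",
    "```\n",
    "\n",
    "In the notebook you could call `looks_like_valid_bucket_name(BUCKET_NAME)` before running `gsutil mb`. `gsutil` itself remains the authoritative check."
   ]
  },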
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "deletable": true,
    "editable": true,
    "id": "fsmCk2dwJnLZ"
   },
   "source": [
    "**Only if your bucket doesn't already exist**: Run the following cell to create your Cloud Storage bucket."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "deletable": true,
    "editable": true,
    "id": "160PRO3aJqLD"
   },
   "outputs": [],
   "source": [
    "! gsutil mb -l $REGION gs://$BUCKET_NAME"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "PyxoF-iqqD1t"
   },
   "source": [
    "### Import libraries\n",
    "\n",
    "Import the libraries we'll be using in this tutorial. This tutorial has been tested with **TensorFlow versions 1.14 and 1.15**."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "MEDlLSWK15UL"
   },
   "outputs": [],
   "source": [
    "import math, json, random\n",
    "import numpy as np\n",
    "import PIL\n",
    "import tensorflow as tf\n",
    "\n",
    "from matplotlib import pyplot as plt\n",
    "from base64 import b64encode\n",
    "\n",
    "\n",
    "print(\"Tensorflow version \" + tf.__version__)\n",
    "AUTO = tf.data.experimental.AUTOTUNE"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "deletable": true,
    "editable": true,
    "id": "aRVMEU2Qshm4"
   },
   "source": [
    "## Downloading and preprocessing the flowers dataset\n",
    "\n",
    "In this section you'll download the flower images (in this dataset they are `TFRecords`), use the `tf.data` API to create a data input pipeline, and split the data into training and validation sets."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "v7HLNsvekxvz"
   },
   "outputs": [],
   "source": [
    "GCS_PATTERN = 'gs://flowers-public/tfrecords-jpeg-192x192-2/*.tfrec'\n",
    "IMAGE_SIZE = [192, 192]\n",
    "\n",
    "BATCH_SIZE = 32 \n",
    "\n",
    "VALIDATION_SPLIT = 0.19\n",
    "CLASSES = ['daisy', 'dandelion', 'roses', 'sunflowers', 'tulips'] # do not change, maps to the labels in the data (folder names)\n",
    "\n",
    "# Split data files between training and validation\n",
    "filenames = tf.gfile.Glob(GCS_PATTERN)\n",
    "random.shuffle(filenames)\n",
    "split = int(len(filenames) * VALIDATION_SPLIT)\n",
    "training_filenames = filenames[split:]\n",
    "validation_filenames = filenames[:split]\n",
    "print(\"Pattern matches {} data files. Splitting dataset into {} training files and {} validation files\".format(len(filenames), len(training_filenames), len(validation_filenames)))\n",
    "validation_steps = int(3670 // len(filenames) * len(validation_filenames)) // BATCH_SIZE\n",
    "steps_per_epoch = int(3670 // len(filenames) * len(training_filenames)) // BATCH_SIZE\n",
    "print(\"With a batch size of {}, there will be {} batches per training epoch and {} batch(es) per validation run.\".format(BATCH_SIZE, steps_per_epoch, validation_steps))"
   ]
  },
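  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "The step counts above are estimates. The cell assumes the 3670 images are spread evenly across the TFRecord files and rounds down at each integer division. The sketch below reproduces that arithmetic as a standalone function so you can see how the numbers come out. `split_steps` is a hypothetical helper, and the file count of 16 in the example call is a hypothetical value chosen for illustration; the `Pattern matches ...` line printed above shows the real count.\n",
    "\n",
    "```python\n",
    "# Hypothetical helper that mirrors the split and step arithmetic in the cell above.\n",
    "TOTAL_IMAGES = 3670  # total image count hard-coded in the cell above\n",
    "\n",
    "\n",
    "def split_steps(num_files, validation_split, batch_size, total_images=TOTAL_IMAGES):\n",
    "  num_validation_files = int(num_files * validation_split)\n",
    "  num_training_files = num_files - num_validation_files\n",
    "  images_per_file = total_images // num_files  # assumes files hold equal counts\n",
    "  validation_steps = images_per_file * num_validation_files // batch_size\n",
    "  steps_per_epoch = images_per_file * num_training_files // batch_size\n",
    "  return num_training_files, num_validation_files, steps_per_epoch, validation_steps\n",
    "\n",
    "\n",
    "# 16 files is an assumed count for illustration only.\n",
    "print(split_steps(num_files=16, validation_split=0.19, batch_size=32))  # (13, 3, 93, 21)\n",
    "```"
   ]
  },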
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "pHsCuE2DDRqX"
   },
   "source": [
    "The following cell contains some image visualization utility functions. This code isn't essential to training or deploying the model. \n",
    "\n",
    "If you're running this from Colab the cell is hidden. You can look at the code by right clicking on the cell --> \"Form\" --> \"Show form\" if you'd like to see it."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "cellView": "form",
    "colab": {},
    "colab_type": "code",
    "id": "v5hLZEH3De5-"
   },
   "outputs": [],
   "source": [
    "#@title display utilities [RUN ME]\n",
    "\n",
    "def dataset_to_numpy_util(dataset, N):\n",
    "  dataset = dataset.batch(N)\n",
    "  \n",
    "  if tf.executing_eagerly():\n",
    "    # In eager mode, iterate in the Datset directly.\n",
    "    for images, labels in dataset:\n",
    "      numpy_images = images.numpy()\n",
    "      numpy_labels = labels.numpy()\n",
    "      break;\n",
    "      \n",
    "  else: # In non-eager mode, must get the TF note that \n",
    "        # yields the nextitem and run it in a tf.Session.\n",
    "    get_next_item = dataset.make_one_shot_iterator().get_next()\n",
    "    with tf.Session() as ses:\n",
    "      numpy_images, numpy_labels = ses.run(get_next_item)\n",
    "\n",
    "  return numpy_images, numpy_labels\n",
    "\n",
    "def title_from_label_and_target(label, correct_label):\n",
    "  label = np.argmax(label, axis=-1)  # one-hot to class number\n",
    "  correct_label = np.argmax(correct_label, axis=-1) # one-hot to class number\n",
    "  correct = (label == correct_label)\n",
    "  return \"{} [{}{}{}]\".format(CLASSES[label], str(correct), ', shoud be ' if not correct else '',\n",
    "                              CLASSES[correct_label] if not correct else ''), correct\n",
    "\n",
    "def display_one_flower(image, title, subplot, red=False):\n",
    "    plt.subplot(subplot)\n",
    "    plt.axis('off')\n",
    "    plt.imshow(image)\n",
    "    plt.title(title, fontsize=16, color='red' if red else 'black')\n",
    "    return subplot+1\n",
    "  \n",
    "def display_9_images_from_dataset(dataset):\n",
    "  subplot=331\n",
    "  plt.figure(figsize=(13,13))\n",
    "  images, labels = dataset_to_numpy_util(dataset, 9)\n",
    "  for i, image in enumerate(images):\n",
    "    title = CLASSES[np.argmax(labels[i], axis=-1)]\n",
    "    subplot = display_one_flower(image, title, subplot)\n",
    "    if i >= 8:\n",
    "      break;\n",
    "              \n",
    "  plt.tight_layout()\n",
    "  plt.subplots_adjust(wspace=0.1, hspace=0.1)\n",
    "  plt.show()\n",
    "  \n",
    "def display_9_images_with_predictions(images, predictions, labels):\n",
    "  subplot=331\n",
    "  plt.figure(figsize=(13,13))\n",
    "  for i, image in enumerate(images):\n",
    "    title, correct = title_from_label_and_target(predictions[i], labels[i])\n",
    "    subplot = display_one_flower(image, title, subplot, not correct)\n",
    "    if i >= 8:\n",
    "      break;\n",
    "              \n",
    "  plt.tight_layout()\n",
    "  plt.subplots_adjust(wspace=0.1, hspace=0.1)\n",
    "  plt.show()\n",
    "  \n",
    "def display_training_curves(training, validation, title, subplot):\n",
    "  if subplot%10==1: # set up the subplots on the first call\n",
    "    plt.subplots(figsize=(10,10), facecolor='#F0F0F0')\n",
    "    plt.tight_layout()\n",
    "  ax = plt.subplot(subplot)\n",
    "  ax.set_facecolor('#F8F8F8')\n",
    "  ax.plot(training)\n",
    "  ax.plot(validation)\n",
    "  ax.set_title('model '+ title)\n",
    "  ax.set_ylabel(title)\n",
    "  ax.set_xlabel('epoch')\n",
    "  ax.legend(['train', 'valid.'])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "5lguLcIRDqj-"
   },
   "source": [
    "### Read images and labels from TFRecords"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "IiQ0GryzDvE9"
   },
   "outputs": [],
   "source": [
    "def read_tfrecord(example):\n",
    "    features = {\n",
    "        \"image\": tf.io.FixedLenFeature([], tf.string), # tf.string means bytestring\n",
    "        \"class\": tf.io.FixedLenFeature([], tf.int64),  # shape [] means scalar\n",
    "        \"one_hot_class\": tf.io.VarLenFeature(tf.float32),\n",
    "    }\n",
    "    example = tf.parse_single_example(example, features)\n",
    "    image = tf.image.decode_jpeg(example['image'], channels=3)\n",
    "    image = tf.cast(image, tf.float32) / 255.0  # convert image to floats in [0, 1] range\n",
    "    image = tf.reshape(image, [*IMAGE_SIZE, 3]) # explicit size will be needed for TPU\n",
    "    class_label = tf.cast(example['class'], tf.int32)\n",
    "    one_hot_class = tf.sparse.to_dense(example['one_hot_class'])\n",
    "    one_hot_class = tf.reshape(one_hot_class, [5])\n",
    "    return image, one_hot_class\n",
    "\n",
    "def load_dataset(filenames):\n",
    "  # Read data from TFRecords\n",
    "\n",
    "  dataset = tf.data.Dataset.from_tensor_slices(filenames)\n",
    "  dataset = dataset.interleave(tf.data.TFRecordDataset, cycle_length=16, num_parallel_calls=AUTO) # faster\n",
    "  dataset = dataset.map(read_tfrecord, num_parallel_calls=AUTO)\n",
    "  return dataset"
   ]
  },
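  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "To make the preprocessing concrete, here is a NumPy sketch of what `read_tfrecord` does after decoding: it scales pixels to `[0, 1]` and gives the one-hot label a fixed shape of `[5]`. The sketch runs on a synthetic all-white image, and `preprocess` is a hypothetical helper, not part of the pipeline. The `[0, 1]` scaling matters later, because the serving input function has to produce pixels in the same range the model was trained on.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "# Hypothetical NumPy mirror of the post-decoding steps in read_tfrecord,\n",
    "# applied to synthetic data instead of a real TFRecord.\n",
    "IMAGE_SIZE = [192, 192]\n",
    "NUM_CLASSES = 5\n",
    "\n",
    "\n",
    "def preprocess(decoded_uint8_image, one_hot_values):\n",
    "  # Scale 8-bit pixel values from [0, 255] to floats in [0, 1].\n",
    "  image = decoded_uint8_image.astype(np.float32) / 255.0\n",
    "  image = image.reshape(IMAGE_SIZE + [3])\n",
    "  # The one-hot label arrives as a variable-length vector; fix its shape to [5].\n",
    "  one_hot = np.asarray(one_hot_values, dtype=np.float32).reshape([NUM_CLASSES])\n",
    "  return image, one_hot\n",
    "\n",
    "\n",
    "fake_image = np.full((192, 192, 3), 255, dtype=np.uint8)\n",
    "image, label = preprocess(fake_image, [0, 0, 1, 0, 0])\n",
    "print(image.shape, image.max(), label.argmax())  # (192, 192, 3) 1.0 2\n",
    "```"
   ]
  },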
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "wPcefe7uEOuT"
   },
   "source": [
    "In the following cell, we'll use a visualization utility function we defined above to preview some flower images with their associated labels."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "PJfoNx-_EOKI"
   },
   "outputs": [],
   "source": [
    "display_9_images_from_dataset(load_dataset(training_filenames))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "0pbkgRIJFPdJ"
   },
   "source": [
    "### Create training and validation datasets"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "jqtZupcNFRSa"
   },
   "outputs": [],
   "source": [
    "def get_batched_dataset(filenames):\n",
    "  dataset = load_dataset(filenames)\n",
    "  dataset = dataset.cache() # This dataset fits in RAM\n",
    "  dataset = dataset.repeat()\n",
    "  dataset = dataset.shuffle(2048)\n",
    "  dataset = dataset.batch(BATCH_SIZE)\n",
    "  dataset = dataset.prefetch(AUTO) # prefetch next batch while training (autotune prefetch buffer size)\n",
    "  # For proper ordering of map/batch/repeat/prefetch, see Dataset performance guide: https://www.tensorflow.org/guide/performance/datasets\n",
    "  return dataset\n",
    "\n",
    "def get_training_dataset():\n",
    "  return get_batched_dataset(training_filenames)\n",
    "\n",
    "def get_validation_dataset():\n",
    "  return get_batched_dataset(validation_filenames)\n",
    "\n",
    "some_flowers, some_labels = dataset_to_numpy_util(load_dataset(validation_filenames), 8*20)"
   ]
  },
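  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "`dataset.shuffle(2048)` doesn't shuffle the entire dataset at once. Per the `tf.data` documentation, it fills a buffer with 2048 elements, samples elements randomly from that buffer, and replaces each sampled element with the next one from the input. The pure-Python sketch below is a conceptual model of that behavior, not TensorFlow's implementation, and `buffered_shuffle` is a hypothetical name. With a buffer smaller than the dataset, an element can appear at most `buffer_size - 1` positions earlier than its input position, so shuffling the file list as well (as the `random.shuffle(filenames)` call above does) helps mix the data.\n",
    "\n",
    "```python\n",
    "import random\n",
    "\n",
    "\n",
    "def buffered_shuffle(items, buffer_size, seed=None):\n",
    "  # Conceptual model of a fixed-size shuffle buffer (not TensorFlow's code):\n",
    "  # fill the buffer, then repeatedly emit a random buffered element and\n",
    "  # put the next input element in its slot.\n",
    "  if buffer_size < 1:\n",
    "    raise ValueError('buffer_size must be at least 1')\n",
    "  rng = random.Random(seed)\n",
    "  buffer = []\n",
    "  for item in items:\n",
    "    if len(buffer) < buffer_size:\n",
    "      buffer.append(item)\n",
    "      continue\n",
    "    index = rng.randrange(buffer_size)\n",
    "    yield buffer[index]\n",
    "    buffer[index] = item\n",
    "  # Once the input is exhausted, drain the remaining elements in random order.\n",
    "  rng.shuffle(buffer)\n",
    "  for item in buffer:\n",
    "    yield item\n",
    "\n",
    "\n",
    "print(list(buffered_shuffle(range(20), buffer_size=4, seed=0)))\n",
    "```"
   ]
  },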
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "deletable": true,
    "editable": true,
    "id": "8zr6lj66UlMn"
   },
   "source": [
    "## Build, train, and evaluate the model\n",
    "\n",
    "In this section we'll define the layers of our model using the Keras Sequential model API. Then we'll run training and evaluation, and finally run some test predictions on the local model.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "deletable": true,
    "editable": true,
    "id": "Icz22E69smnD"
   },
   "outputs": [],
   "source": [
    "model = tf.keras.Sequential([\n",
    "    tf.keras.layers.Conv2D(kernel_size=3, filters=16, padding='same', activation='relu', input_shape=[*IMAGE_SIZE, 3]),\n",
    "    tf.keras.layers.Conv2D(kernel_size=3, filters=30, padding='same', activation='relu'),\n",
    "    tf.keras.layers.MaxPooling2D(pool_size=2),\n",
    "    tf.keras.layers.Conv2D(kernel_size=3, filters=60, padding='same', activation='relu'),\n",
    "    tf.keras.layers.MaxPooling2D(pool_size=2),\n",
    "    tf.keras.layers.Conv2D(kernel_size=3, filters=90, padding='same', activation='relu'),\n",
    "    tf.keras.layers.MaxPooling2D(pool_size=2),\n",
    "    tf.keras.layers.Conv2D(kernel_size=3, filters=110, padding='same', activation='relu'),\n",
    "    tf.keras.layers.MaxPooling2D(pool_size=2),\n",
    "    tf.keras.layers.Conv2D(kernel_size=3, filters=130, padding='same', activation='relu'),\n",
    "    tf.keras.layers.Conv2D(kernel_size=1, filters=40, padding='same', activation='relu'),\n",
    "    tf.keras.layers.GlobalAveragePooling2D(),\n",
    "    tf.keras.layers.Dense(5, activation='softmax')\n",
    "])\n",
    "\n",
    "model.compile(\n",
    "  optimizer='adam',\n",
    "  loss= 'categorical_crossentropy',\n",
    "  metrics=['accuracy'])\n",
    "\n",
    "model.summary()"
   ]
  },
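  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "As a cross-check on the `model.summary()` output, the sketch below recomputes the layer arithmetic by hand. With `'same'` padding each convolution keeps the spatial size, each `MaxPooling2D(pool_size=2)` halves it, and each convolution has `kernel_size * kernel_size * input_channels + 1` parameters per filter (weights plus one bias). The layer list is copied from the cell above, and `summarize` is a hypothetical helper for illustration.\n",
    "\n",
    "```python\n",
    "# Hypothetical helper that redoes the layer arithmetic of the model above.\n",
    "# (kernel_size, filters, followed_by_max_pooling), copied from the cell above.\n",
    "CONV_LAYERS = [\n",
    "    (3, 16, False),\n",
    "    (3, 30, True),\n",
    "    (3, 60, True),\n",
    "    (3, 90, True),\n",
    "    (3, 110, True),\n",
    "    (3, 130, False),\n",
    "    (1, 40, False),\n",
    "]\n",
    "NUM_CLASSES = 5\n",
    "\n",
    "\n",
    "def summarize(image_size=192, channels=3):\n",
    "  size, in_channels, total_params = image_size, channels, 0\n",
    "  for kernel_size, filters, max_pool in CONV_LAYERS:\n",
    "    # Each filter has kernel_size * kernel_size * in_channels weights plus one bias;\n",
    "    # 'same' padding leaves the spatial size unchanged.\n",
    "    total_params += (kernel_size * kernel_size * in_channels + 1) * filters\n",
    "    in_channels = filters\n",
    "    if max_pool:\n",
    "      size //= 2  # MaxPooling2D(pool_size=2) halves height and width\n",
    "  # GlobalAveragePooling2D keeps one value per channel; the Dense layer adds\n",
    "  # (in_channels + 1) * NUM_CLASSES weights and biases.\n",
    "  total_params += (in_channels + 1) * NUM_CLASSES\n",
    "  return size, in_channels, total_params\n",
    "\n",
    "\n",
    "print(summarize())  # (12, 40, 293233)\n",
    "```\n",
    "\n",
    "The total should match the `Total params` line printed by `model.summary()`. If it doesn't, the model definition has drifted from the list above."
   ]
  },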
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "z3PBnvR3FCaY"
   },
   "source": [
    "### Train the model\n",
    "\n",
    "Train this on a GPU if you have access (in Colab, from the menu select **Runtime** --> **Change runtime type**). On a CPU, it'll take ~30 minutes to run training. On a GPU, it takes ~5 minutes."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "vxZryg4xmdy0"
   },
   "outputs": [],
   "source": [
    "EPOCHS = 20 # Train for 60 epochs for higher accuracy, 20 should get you ~75%\n",
    "\n",
    "history = model.fit(get_training_dataset(), steps_per_epoch=steps_per_epoch, epochs=EPOCHS,\n",
    "                    validation_data=get_validation_dataset(), validation_steps=validation_steps)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "wJ9ks5bGGBwz"
   },
   "source": [
    "### Get predictions on local model and visualize them"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "yMEsR851VDZb"
   },
   "outputs": [],
   "source": [
    "# Randomize the input so that you can execute multiple times to change results\n",
    "permutation = np.random.permutation(8*20)\n",
    "some_flowers, some_labels = (some_flowers[permutation], some_labels[permutation])\n",
    "\n",
    "predictions = model.predict(some_flowers, batch_size=16)\n",
    "evaluations = model.evaluate(some_flowers, some_labels, batch_size=16)\n",
    "  \n",
    "print(np.array(CLASSES)[np.argmax(predictions, axis=-1)].tolist())\n",
    "print('[val_loss, val_acc]', evaluations)"
   ]
  },
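  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "With a softmax output, one-hot labels, and `categorical_crossentropy` loss, Keras should resolve the `accuracy` metric to categorical accuracy: the fraction of images whose arg-max prediction matches the arg-max of the label. That is the same comparison `title_from_label_and_target` uses to decide which titles to color red. The NumPy sketch below reproduces the computation on made-up values. `categorical_accuracy` here is a hypothetical standalone helper, not the Keras function of the same name.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "\n",
    "def categorical_accuracy(one_hot_labels, predicted_probs):\n",
    "  # A prediction is correct when its highest-probability class matches\n",
    "  # the position of the 1 in the one-hot label.\n",
    "  true_classes = np.argmax(one_hot_labels, axis=-1)\n",
    "  predicted_classes = np.argmax(predicted_probs, axis=-1)\n",
    "  return float(np.mean(true_classes == predicted_classes))\n",
    "\n",
    "\n",
    "# Made-up values: 3 samples over the 5 flower classes.\n",
    "toy_labels = np.eye(5)[[0, 3, 4]]\n",
    "toy_probs = np.array([\n",
    "    [0.70, 0.10, 0.10, 0.05, 0.05],  # predicts class 0: correct\n",
    "    [0.20, 0.50, 0.10, 0.10, 0.10],  # predicts class 1: wrong (label is 3)\n",
    "    [0.00, 0.10, 0.10, 0.10, 0.70],  # predicts class 4: correct\n",
    "])\n",
    "print(categorical_accuracy(toy_labels, toy_probs))  # 2 of 3 correct\n",
    "```\n",
    "\n",
    "In the notebook, `categorical_accuracy(some_labels, predictions)` should agree with the accuracy value in `evaluations`."
   ]
  },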
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "qzCCDL1CZFx6"
   },
   "outputs": [],
   "source": [
    "display_9_images_with_predictions(some_flowers, predictions, some_labels)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "kV_NEAQwwH0e"
   },
   "source": [
    "## Export the model as a TF 1 SavedModel\n",
    "\n",
    "AI Explanations currently supports TensorFlow 1.x. In order to deploy our model in a format compatible with AI Explanations, we'll follow the steps below to convert our Keras model to a TF Estimator, and then use the `export_saved_model` method to generate the SavedModel and save it in GCS."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "2CRJTE7BXPJE"
   },
   "outputs": [],
   "source": [
    "## Convert our Keras model to an estimator and then export to SavedModel\n",
    "keras_estimator = tf.keras.estimator.model_to_estimator(keras_model=model, model_dir='savedmodel_export')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "mLq6rRf6Ge3X"
   },
   "source": [
    "The `decode_img_bytes` function below handles converting image bytes (the format our served model will expect) to floats: a `[192,192,3]` dimensional matrix that our model is expecting. For image explanations models, we recommend this approach rather than sending an image as a float array from the client."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "2DPLpyk5XXHT"
   },
   "outputs": [],
   "source": [
    "def decode_img_bytes(img_bytes, height, width, color_depth):\n",
    "  features = tf.squeeze(img_bytes, axis=1, name='input_squeeze')\n",
    "  float_pixels = tf.map_fn(\n",
    "    lambda img_string: tf.io.decode_image(\n",
    "        img_string, \n",
    "        channels=color_depth,\n",
    "        dtype=tf.float32\n",
    "    ),\n",
    "    features,\n",
    "    dtype=tf.float32,\n",
    "    name='input_convert'\n",
    "  )\n",
    "\n",
    "  tf.Tensor.set_shape(float_pixels, (None, height, width, color_depth))\n",
    "  float_pixels = tf.identity(float_pixels, name='input_pixels')\n",
    "\n",
    "  return float_pixels\n",
    "\n",
    "def serving_input_receiver_fn():\n",
    "  img_bytes = tf.placeholder(shape=(None,1), dtype=tf.string)\n",
    "  img_float = decode_img_bytes(img_bytes, 192,192, 3)\n",
    "  return tf.estimator.export.ServingInputReceiver({'conv2d_input': img_float}, {'conv2d_input': img_bytes})"
   ]
  },
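  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "On the client side, this serving signature expects encoded image bytes rather than pixel arrays. The following is a minimal sketch of building an online-prediction request body for it. It assumes the AI Platform JSON convention of wrapping binary data in `{'b64': ...}` objects, and it assumes the input key `conv2d_input` from `serving_input_receiver_fn` above. `build_instances` is a hypothetical helper; confirm the input name against the `saved_model_cli` output in the cells that follow.\n",
    "\n",
    "```python\n",
    "import json\n",
    "from base64 import b64encode\n",
    "\n",
    "\n",
    "def build_instances(image_bytes_list, input_key='conv2d_input'):\n",
    "  # Assumed request shape: the serving placeholder is a string tensor of shape\n",
    "  # (None, 1), so each instance is a one-element list, and binary data is sent\n",
    "  # as a {'b64': ...} object for the prediction service to decode.\n",
    "  instances = []\n",
    "  for image_bytes in image_bytes_list:\n",
    "    encoded = b64encode(image_bytes).decode('utf-8')\n",
    "    instances.append({input_key: [{'b64': encoded}]})\n",
    "  return instances\n",
    "\n",
    "\n",
    "# Placeholder bytes stand in for the contents of a real JPEG file.\n",
    "fake_jpeg_header = bytes([0xFF, 0xD8, 0xFF, 0xE0])\n",
    "print(json.dumps({'instances': build_instances([fake_jpeg_header])}))\n",
    "```\n",
    "\n",
    "In practice you would read the bytes of a real image file, for example with `open(path, 'rb').read()`, and send the resulting body to the deployed version."
   ]
  },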
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "Yn8hEAHiXaE7"
   },
   "outputs": [],
   "source": [
    "export_path = keras_estimator.export_saved_model(\n",
    "  'gs://' + BUCKET_NAME + '/explanations',\n",
    "  serving_input_receiver_fn\n",
    ").decode('utf-8')\n",
    "print(\"Model exported to: \", export_path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "S1mTufKyX8rr"
   },
   "outputs": [],
   "source": [
    "!saved_model_cli show --dir $export_path --all"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "E8g1JW03HO5v"
   },
   "source": [
    "## Generate the metadata for AI Explanations\n",
    "\n",
    "In order to deploy this model to Cloud Explanations, we need to create an `explanation_metadata.json` file with information about our model inputs, outputs, and baseline. \n",
    "\n",
    "For image models, using `[0,1]` as your input baseline represents black and white images. In this case we're using `np.random` to generate the baseline because our training images contain a lot of black and white (i.e. daisy petals)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "ixQHycu-ahEB"
   },
   "outputs": [],
   "source": [
    "random_baseline = np.random.rand(192,192,3)"
   ]
  },
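  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "The baseline must match the model input in shape and value range (`[192, 192, 3]`, floats in `[0, 1]`). For comparison, the sketch below builds black, white, and random-noise baselines side by side. Only `random_baseline` from the cell above is used in this tutorial; the other arrays are illustrative. Each baseline holds 110,592 values, and every one of them is written into `explanation_metadata.json`, so the resulting file is on the order of megabytes.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "# Illustrative alternatives for a 192x192 RGB input scaled to [0, 1].\n",
    "black_baseline = np.zeros((192, 192, 3))      # all 0s: a black image\n",
    "white_baseline = np.ones((192, 192, 3))       # all 1s: a white image\n",
    "noise_baseline = np.random.rand(192, 192, 3)  # uniform noise in [0, 1), like random_baseline\n",
    "\n",
    "for name, baseline in [('black', black_baseline),\n",
    "                       ('white', white_baseline),\n",
    "                       ('noise', noise_baseline)]:\n",
    "  # Each baseline has 192 * 192 * 3 = 110592 values, all serialized to JSON.\n",
    "  print(name, baseline.shape, baseline.size, round(float(baseline.mean()), 2))\n",
    "```"
   ]
  },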
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "8KY_wPitX9Gl"
   },
   "outputs": [],
   "source": [
    "explanation_metadata = {\n",
    "    \"inputs\": {\n",
    "      \"data\": {\n",
    "        \"input_tensor_name\": \"input_pixels:0\",\n",
    "        \"modality\": \"image\",\n",
    "        \"input_baselines\": [random_baseline.tolist()]\n",
    "      }\n",
    "    },\n",
    "    \"outputs\": {\n",
    "      \"probability\": {\n",
    "        \"output_tensor_name\": \"dense/Softmax:0\"\n",
    "      }\n",
    "    },\n",
    "  \"framework\": \"tensorflow\"\n",
    "  }"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "ovef6vCwRCAH"
   },
   "outputs": [],
   "source": [
    "# Write the json to a local file\n",
    "with open('explanation_metadata.json', 'w') as output_file:\n",
    "  json.dump(explanation_metadata, output_file)"
   ]
  },
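  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "Because the baseline is stored as a large nested list, a shape or range mismatch is hard to spot by eye. The following is a hypothetical local check (`check_image_metadata` is not part of any SDK). It verifies that each baseline matches the model's input shape and `[0, 1]` value range, then returns the output names. The sketch runs on a tiny stand-in dictionary so that it is self-contained.\n",
    "\n",
    "```python\n",
    "import json\n",
    "import numpy as np\n",
    "\n",
    "\n",
    "def check_image_metadata(metadata, expected_shape=(192, 192, 3)):\n",
    "  # Confirm every input baseline matches the model input shape and [0, 1] range.\n",
    "  for input_name, spec in metadata['inputs'].items():\n",
    "    baselines = np.array(spec['input_baselines'])\n",
    "    # input_baselines is a list of baselines, so axis 0 counts them.\n",
    "    if baselines.shape[1:] != expected_shape:\n",
    "      raise ValueError('%s: baseline shape %s' % (input_name, baselines.shape))\n",
    "    if baselines.min() < 0.0 or baselines.max() > 1.0:\n",
    "      raise ValueError('%s: baseline values outside [0, 1]' % input_name)\n",
    "  return sorted(metadata['outputs'])\n",
    "\n",
    "\n",
    "# Round-trip a tiny stand-in through JSON, as json.dump and json.load would.\n",
    "tiny = {\n",
    "    'inputs': {'data': {'input_tensor_name': 'input_pixels:0', 'modality': 'image',\n",
    "                        'input_baselines': [np.random.rand(2, 2, 3).tolist()]}},\n",
    "    'outputs': {'probability': {'output_tensor_name': 'dense/Softmax:0'}},\n",
    "    'framework': 'tensorflow',\n",
    "}\n",
    "print(check_image_metadata(json.loads(json.dumps(tiny)), expected_shape=(2, 2, 3)))\n",
    "```\n",
    "\n",
    "To check the file written above, pass `json.load(open('explanation_metadata.json'))` with the default `expected_shape` of `(192, 192, 3)`. It should return `['probability']`."
   ]
  },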
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "PjIiTSTiYAK9"
   },
   "outputs": [],
   "source": [
    "# Copy this file into the GCS location with our SavedModel assets\n",
    "!gsutil cp explanation_metadata.json $export_path"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "0fyT3PQrH1PJ"
   },
   "source": [
    "## Deploy model to AI Explanations\n",
    "\n",
    "In this step we'll use the `gcloud` CLI to deploy our model to AI Explanations."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "X1Oj9C2TIM7L"
   },
   "source": [
    "### Create the model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "lgd4HjJTIM7P"
   },
   "outputs": [],
   "source": [
    "MODEL = 'flowers'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "RuG90JjQIM7S"
   },
   "outputs": [],
   "source": [
    "# Create the model if it doesn't exist yet (you only need to run this once)\n",
    "!gcloud ai-platform models create $MODEL --enable-logging --region=$REGION"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "7CJycA-1IM7b"
   },
   "source": [
    "### Create explainable model versions\n",
    "\n",
    "For image models, we offer two choices for explanation methods: \n",
    "* Integrated Gradients (IG)\n",
    "* XRAI \n",
    "\n",
    "You can find more info on each method in the [documentation](TODO). Below, we'll show you how to deploy a version with both so that you can compare results. **If you already know which explanation method you'd like to use, you can deploy one version and skip the code blocks for the other method.**\n",
    "\n",
    "Creating the version will take ~5-10 minutes. Note that your first deploy may take longer."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "QQZgzF2I25jf"
   },
   "source": [
    "#### Deploy an explainable model with Integrated Gradients"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "QrpL2cJ7IM7c"
   },
   "outputs": [],
   "source": [
    "# Each time you create a version the name should be unique\n",
    "IG_VERSION = 'v_ig'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "r9zZbEAeIM7h"
   },
   "outputs": [],
   "source": [
    "# Create the version with gcloud\n",
    "!gcloud beta ai-platform versions create $IG_VERSION --region=$REGION \\\n",
    "--model $MODEL \\\n",
    "--origin $export_path \\\n",
    "--runtime-version 1.15 \\\n",
    "--framework TENSORFLOW \\\n",
    "--python-version 3.7 \\\n",
    "--machine-type n1-standard-4 \\\n",
    "--explanation-method integrated-gradients \\\n",
    "--num-integral-steps 25"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "XORWSOl3IM7i"
   },
   "outputs": [],
   "source": [
    "# Make sure the IG model deployed correctly. State should be `READY` in the following log\n",
    "!gcloud ai-platform versions describe $IG_VERSION --model $MODEL --region=$REGION"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "vgvDwajk3Zus"
   },
   "source": [
    "#### Deploy an explainable model with XRAI"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "PT6WwCDr3c4Z"
   },
   "outputs": [],
   "source": [
    "# Version names must be unique within a model, so change this if you redeploy\n",
    "XRAI_VERSION = 'v_xrai'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "ZoRixfJp3c4f"
   },
   "outputs": [],
   "source": [
    "# Create the XRAI version with gcloud\n",
    "!gcloud beta ai-platform versions create $XRAI_VERSION --region=$REGION \\\n",
    "--model $MODEL \\\n",
    "--origin $export_path \\\n",
    "--runtime-version 1.15 \\\n",
    "--framework TENSORFLOW \\\n",
    "--python-version 3.7 \\\n",
    "--machine-type n1-standard-4 \\\n",
    "--explanation-method xrai \\\n",
    "--num-integral-steps 25"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "khcdj5osM6ml"
   },
   "outputs": [],
   "source": [
    "# Make sure the XRAI model deployed correctly. State should be `READY` in the following log\n",
    "!gcloud ai-platform versions describe $XRAI_VERSION --model $MODEL --region=$REGION"
   ]
  },
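  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "The IG and XRAI `versions create` commands above share every flag except `--explanation-method`. As a minimal sketch, the command can be assembled in one place so the two versions can't drift apart. `version_create_args` is a hypothetical helper (not part of `gcloud`), and the bucket path and region in the example call are placeholders:\n",
    "\n",
    "```python\n",
    "def version_create_args(version, model, origin, region, method, num_steps=25):\n",
    "    # Hypothetical helper: rebuilds the flags used in the two create cells above.\n",
    "    # This notebook passes --num-integral-steps to both explanation methods.\n",
    "    if method not in ('integrated-gradients', 'xrai'):\n",
    "        raise ValueError('unsupported explanation method: ' + method)\n",
    "    return [\n",
    "        'gcloud', 'beta', 'ai-platform', 'versions', 'create', version,\n",
    "        '--region=' + region,\n",
    "        '--model', model,\n",
    "        '--origin', origin,\n",
    "        '--runtime-version', '1.15',\n",
    "        '--framework', 'TENSORFLOW',\n",
    "        '--python-version', '3.7',\n",
    "        '--machine-type', 'n1-standard-4',\n",
    "        '--explanation-method', method,\n",
    "        '--num-integral-steps', str(num_steps),\n",
    "    ]\n",
    "\n",
    "# Placeholder bucket path and region, for illustration only\n",
    "args = version_create_args('v_xrai', 'flowers', 'gs://your-bucket/export', 'us-central1', 'xrai')\n",
    "print(' '.join(args))\n",
    "```\n",
    "\n",
    "If you prefer to run the command from Python instead of a `!gcloud` line, the list can be passed to `subprocess.run(args, check=True)`."
   ]
  },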
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "MWQ72uqOIelF"
   },
   "source": [
    "## Get predictions and explanations on deployed model\n",
    "\n",
    "Here we'll prepare some test images to send to our model. Then we'll use the AI Platform Prediction API to get the model's predicted class along with the explanation for each image."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "_9p4ESRVH0rk"
   },
   "outputs": [],
   "source": [
    "# Download test flowers from public bucket\n",
    "!mkdir -p flowers\n",
    "!gsutil -m cp gs://flowers_model/test_flowers/* ./flowers"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "LvaUEUd0ToaV"
   },
   "outputs": [],
   "source": [
    "# Resize the images to what our model is expecting (192,192)\n",
    "test_filenames = []\n",
    "\n",
    "for i in os.listdir('flowers'):\n",
    "  img_path = 'flowers/' + i\n",
    "  with PIL.Image.open(img_path) as ex_img:\n",
    "    resize_img = ex_img.resize([192,192])\n",
    "    resize_img.save(img_path)\n",
    "    test_filenames.append(img_path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "ingRZrUOJF-A"
   },
   "outputs": [],
   "source": [
    "# Prepare our prediction JSON to send to our Cloud model\n",
    "instances = []\n",
    "\n",
    "for i in test_filenames:\n",
    "  with open(i, 'rb') as example_img:\n",
    "    b64str = b64encode(example_img.read()).decode('utf-8')\n",
    "    instances.append({'conv2d_input': [{'b64': b64str}]})"
   ]
  },
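  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "Each instance wraps base64-encoded image bytes under the model's input name, `conv2d_input`, and the request body carries a list of instances. A minimal standard-library sketch of that shape, using placeholder bytes instead of real JPEG files (`make_instance` is a hypothetical helper):\n",
    "\n",
    "```python\n",
    "import base64\n",
    "import json\n",
    "\n",
    "def make_instance(image_bytes, input_name='conv2d_input'):\n",
    "    # Same nesting as the cell above: {input_name: [{'b64': <base64 string>}]}\n",
    "    b64str = base64.b64encode(image_bytes).decode('utf-8')\n",
    "    return {input_name: [{'b64': b64str}]}\n",
    "\n",
    "# Placeholder bytes stand in for the contents of real image files\n",
    "fake_images = [b'first image bytes', b'second image bytes']\n",
    "body = {'instances': [make_instance(img) for img in fake_images]}\n",
    "print(json.dumps(body, indent=2))\n",
    "```\n",
    "\n",
    "Even a single image is sent as a one-element list of instances."
   ]
  },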
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "US8Tg-WE9QId"
   },
   "source": [
    "The `predict_json` function below sends the image data to the `explain` method of the specified model version, so each response contains both the predicted class and its explanation."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "TDC7D5fh7h98"
   },
   "outputs": [],
   "source": [
    "# This is adapted from a sample in the docs\n",
    "# Find it here: https://cloud.google.com/ai-platform/prediction/docs/online-predict#python\n",
    "\n",
    "def predict_json(project, model, instances, version=None):\n",
    "    \"\"\"Send json data to a deployed model for prediction.\n",
    "\n",
    "    Args:\n",
    "        project (str): project where the AI Platform Model is deployed.\n",
    "        model (str): model name.\n",
    "        instances ([Mapping[str: Any]]): Keys should be the names of Tensors\n",
    "            your deployed model expects as inputs. Values should be datatypes\n",
    "            convertible to Tensors, or (potentially nested) lists of datatypes\n",
    "            convertible to tensors.\n",
    "        version: str, version of the model to target.\n",
    "    Returns:\n",
    "        Mapping[str: any]: dictionary of prediction results defined by the\n",
    "            model.\n",
    "    \"\"\"\n",
    "\n",
    "    service = googleapiclient.discovery.build('ml', 'v1')\n",
    "    name = 'projects/{}/models/{}'.format(project, model)\n",
    "\n",
    "    if version is not None:\n",
    "        name += '/versions/{}'.format(version)\n",
    "\n",
    "    response = service.projects().explain(\n",
    "        name=name,\n",
    "        body={'instances': instances}\n",
    "    ).execute()\n",
    "\n",
    "    if 'error' in response:\n",
    "        raise RuntimeError(response['error'])\n",
    "\n",
    "    return response"
   ]
  },
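  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "The model and versions above were created with `--region`, while `predict_json` builds its client without specifying an endpoint. If `REGION` is a regional endpoint such as `us-central1`, the client may need to target that endpoint explicitly. The sketch below assumes the `https://<region>-ml.googleapis.com` convention used in the AI Platform Prediction docs and passes it through the `client_options` argument of `googleapiclient.discovery.build`. `api_endpoint`, `resource_name`, and `build_service` are hypothetical helper names:\n",
    "\n",
    "```python\n",
    "def api_endpoint(region=None):\n",
    "    # Assumed convention: regional endpoints are https://<region>-ml.googleapis.com;\n",
    "    # no region (or 'global') falls back to the global endpoint.\n",
    "    if not region or region == 'global':\n",
    "        return 'https://ml.googleapis.com'\n",
    "    return 'https://{}-ml.googleapis.com'.format(region)\n",
    "\n",
    "def resource_name(project, model, version=None):\n",
    "    # Same resource path that predict_json builds.\n",
    "    name = 'projects/{}/models/{}'.format(project, model)\n",
    "    if version is not None:\n",
    "        name += '/versions/{}'.format(version)\n",
    "    return name\n",
    "\n",
    "def build_service(region=None):\n",
    "    # Imported here so the string helpers above work without the client library.\n",
    "    import googleapiclient.discovery\n",
    "    return googleapiclient.discovery.build(\n",
    "        'ml', 'v1', client_options={'api_endpoint': api_endpoint(region)})\n",
    "\n",
    "print(api_endpoint('us-central1'))\n",
    "print(resource_name('my-project', 'flowers', 'v_ig'))\n",
    "```\n",
    "\n",
    "Inside `predict_json`, `service = build_service(REGION)` could then replace the plain `build('ml', 'v1')` call."
   ]
  },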
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "5-B1VlyszaSz"
   },
   "source": [
    "### Make an AI Explanations request with the API client\n",
    "\n",
    "First we'll get the explanation results for IG, then we'll compare them with XRAI.\n",
    "\n",
    "**If you only deployed one version above, run only the cell for that explanation method.**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "8bhIjK_v6CH3"
   },
   "outputs": [],
   "source": [
    "# IG EXPLANATIONS\n",
    "ig_response = predict_json(PROJECT_ID, MODEL, instances, IG_VERSION)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "ivs78yzb4P-f"
   },
   "outputs": [],
   "source": [
    "# XRAI EXPLANATIONS\n",
    "xrai_response = predict_json(PROJECT_ID, MODEL, instances, XRAI_VERSION)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "vu88JCaONgWN"
   },
   "source": [
    "### See our model's predicted classes without explanations\n",
    "\n",
    "First, let's preview the images and see what our model predicted for them. Why did the model predict these classes? We'll see explanations in the next section."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "cFRb8JpcTUgF"
   },
   "outputs": [],
   "source": [
    "from io import BytesIO\n",
    "import matplotlib.image as mpimg\n",
    "import base64\n",
    "\n",
    "# Note: change the `ig_response` variable below if you didn't deploy an IG model\n",
    "for idx, val in enumerate(ig_response['explanations']):\n",
    "    top_label = val['attributions_by_label'][0]\n",
    "    class_name = CLASSES[top_label['label_index']]\n",
    "    confidence_score = str(round(top_label['example_score'] * 100, 3)) + '%'\n",
    "    print('Predicted class: ' + class_name + '\\n' + 'Confidence score: ' + confidence_score)\n",
    "\n",
    "    # Decode the base64 image we sent in the request and display it\n",
    "    img_b64 = instances[idx]['conv2d_input'][0]['b64']\n",
    "    img = mpimg.imread(BytesIO(base64.b64decode(img_b64)), format='JPG')\n",
    "    plt.imshow(img, interpolation='nearest')\n",
    "    plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "pswFW9SeUfqj"
   },
   "source": [
    "### Visualize the images with AI Explanations\n",
    "\n",
    "Now let's look at the explanations.  \n",
    "\n",
    "The images returned show the explanations for **only the top class predicted by the model**. This means that if one of our model's predictions is incorrect, the pixels you see highlighted are for the _incorrect class_. For example, if the model predicted rose when it should have predicted tulip, you'll see explanations for why the model thought this image was a rose.\n",
    "\n",
    "First, we'll visualize the attributions for our **Integrated Gradients version**. Currently, the images returned by AI Explanations highlight the top 60% of pixels that contributed to the model's prediction, so the highlighted regions you'll see after running the cell below are the pixels that most strongly signaled the predicted class."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "HwkScD4kUhsU"
   },
   "outputs": [],
   "source": [
    "import io\n",
    "\n",
    "for idx, flower in enumerate(ig_response['explanations']):\n",
    "  top_label = flower['attributions_by_label'][0]\n",
    "  predicted_flower = CLASSES[top_label['label_index']]\n",
    "  confidence = top_label['example_score']\n",
    "  print('Predicted flower:', predicted_flower, '| Confidence:', round(confidence, 3))\n",
    "  # The attribution overlay comes back as a base64-encoded JPEG\n",
    "  b64str = top_label['attributions']['data']['b64_jpeg']\n",
    "  attr_img = mpimg.imread(io.BytesIO(base64.b64decode(b64str)), format='JPG')\n",
    "\n",
    "  plt.imshow(attr_img, interpolation='nearest')\n",
    "  plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "NNPUGCNn6A-j"
   },
   "source": [
    "Let's compare this with the image explanations we get from our XRAI version."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "tEnHKBlW6JZc"
   },
   "outputs": [],
   "source": [
    "# Import io here too, in case you skipped the IG cell above\n",
    "import io\n",
    "\n",
    "for idx, flower in enumerate(xrai_response['explanations']):\n",
    "  top_label = flower['attributions_by_label'][0]\n",
    "  predicted_flower = CLASSES[top_label['label_index']]\n",
    "  confidence = top_label['example_score']\n",
    "  print('Predicted flower:', predicted_flower, '| Confidence:', round(confidence, 3))\n",
    "  # The attribution overlay comes back as a base64-encoded JPEG\n",
    "  b64str = top_label['attributions']['data']['b64_jpeg']\n",
    "  attr_img = mpimg.imread(io.BytesIO(base64.b64decode(b64str)), format='JPG')\n",
    "\n",
    "  plt.imshow(attr_img, interpolation='nearest')\n",
    "  plt.show()"
   ]
  },
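  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "The IG overlays above highlight the top 60% of pixels. As a toy illustration of one way such a cutoff can work, the sketch below ranks the pixels of a made-up 3x3 attribution map and keeps the highest-ranked 60% by pixel count. Counting pixels (rather than attribution mass) is an assumption here, this is not AI Explanations' actual visualization code, and `top_fraction_mask` is a hypothetical name:\n",
    "\n",
    "```python\n",
    "import math\n",
    "\n",
    "def top_fraction_mask(attributions, fraction=0.6):\n",
    "    # Illustrative only: mark the `fraction` of pixels (by count) with the\n",
    "    # largest attribution values; ties at the cutoff are all kept.\n",
    "    flat = sorted((v for row in attributions for v in row), reverse=True)\n",
    "    keep = max(1, math.ceil(fraction * len(flat)))\n",
    "    cutoff = flat[keep - 1]\n",
    "    return [[v >= cutoff for v in row] for row in attributions]\n",
    "\n",
    "# Made-up attribution values for a 3x3 image\n",
    "toy = [[0.9, 0.1, 0.4],\n",
    "       [0.0, 0.7, 0.2],\n",
    "       [0.3, 0.8, 0.5]]\n",
    "for row in top_fraction_mask(toy):\n",
    "    print(row)\n",
    "```"
   ]
  },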
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "BZiM7kywQy6j"
   },
   "source": [
    "## Sanity check our explanations\n",
    "\n",
    "To better make sense of the feature attributions we're getting, we should compare them with our model's baseline. For image models, the `baseline_score` returned by AI Explanations is the score our model gives the baseline image we specified, for the predicted class. The baseline score differs from class to class but does not depend on the input image: every time your model predicts `tulip` as the top class, you'll see the same baseline score.\n",
    "\n",
    "In this case, we used a baseline image of random values generated with `np.random`. If you'd like your model's baselines to be a solid black image and a solid white image instead, pass `[0,1]` as the value of `input_baselines` in the `explanation_metadata.json` file you created above.\n",
    "\n",
    "If the `baseline_score` is very close to the value of `example_score`, the highlighted pixels may not be meaningful. \n",
    "\n",
    "Below we'll calculate the difference between `baseline_score` and `example_score` for the 3 test images above.\n",
    "\n",
    "Note that the score values for classification models are _probabilities_: the confidence your model has in its predicted class. A score of 0.90 for tulip means your model has classified the image as a tulip with 90% confidence.\n",
    "\n",
    "We're running sanity checks below on our IG model. To inspect your XRAI model instead, replace `ig_response` with `xrai_response` and `IG_VERSION` with `XRAI_VERSION` in the cells below."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "CSf6psVDSDrN"
   },
   "outputs": [],
   "source": [
    "for i,val in enumerate(ig_response['explanations']):\n",
    "  baseline_score = val['attributions_by_label'][0]['baseline_score']\n",
    "  predicted_score = val['attributions_by_label'][0]['example_score']\n",
    "  print('Baseline score: ', baseline_score) \n",
    "  print('Predicted score: ', predicted_score)\n",
    "  print('Predicted - Baseline: ', predicted_score - baseline_score, '\\n')"
   ]
  },
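  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "The same comparison can be wrapped in a helper that flags explanations whose top-class score barely moves away from the baseline score, where the highlighted pixels may not be meaningful. The helper name, the 0.05 tolerance, and the sample response are made up for illustration; only the field names come from the cell above:\n",
    "\n",
    "```python\n",
    "def flag_weak_explanations(response, tolerance=0.05):\n",
    "    # Returns (index, example_score - baseline_score) for each explanation\n",
    "    # whose score difference is smaller than the (hypothetical) tolerance.\n",
    "    weak = []\n",
    "    for idx, explanation in enumerate(response['explanations']):\n",
    "        top = explanation['attributions_by_label'][0]\n",
    "        delta = top['example_score'] - top['baseline_score']\n",
    "        if abs(delta) < tolerance:\n",
    "            weak.append((idx, delta))\n",
    "    return weak\n",
    "\n",
    "# Made-up response with the same nesting as ig_response\n",
    "sample_response = {'explanations': [\n",
    "    {'attributions_by_label': [{'label_index': 2, 'example_score': 0.91, 'baseline_score': 0.12}]},\n",
    "    {'attributions_by_label': [{'label_index': 4, 'example_score': 0.40, 'baseline_score': 0.38}]},\n",
    "]}\n",
    "print(flag_weak_explanations(sample_response))\n",
    "```\n",
    "\n",
    "On the real results this would be called as `flag_weak_explanations(ig_response)`."
   ]
  },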
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "aLt8L73w3G9N"
   },
   "source": [
    "As another sanity check, we'll also look at the explanations for this model's baseline image: an image array of randomly generated values using `np.random`. First, we'll convert the same `np.random` baseline array we generated above to a base64 string and preview it."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "Gx276Jwp3Vh_"
   },
   "outputs": [],
   "source": [
    "# Convert our baseline from above to a base64 string\n",
    "rand_test_img = PIL.Image.fromarray((random_baseline * 255).astype('uint8'))\n",
    "buffer = BytesIO()\n",
    "rand_test_img.save(buffer, format=\"BMP\")\n",
    "new_image_string = base64.b64encode(buffer.getvalue()).decode(\"utf-8\")\n",
    "\n",
    "# Preview it\n",
    "plt.imshow(rand_test_img)"
   ]
  },
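  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "The conversion above scales the float baseline to 0-255 and casts it with `astype('uint8')`, which truncates rather than rounds. Assuming the baseline values lie in [0, 1), each 8-bit pixel, mapped back to [0, 1], differs from the float baseline by less than 1/255 (BMP is lossless, so saving adds no further error). This is one reason the later sanity check compares the scores as close to each other rather than exactly equal. A pure-Python sketch of that bound on made-up random values, with hypothetical helper names:\n",
    "\n",
    "```python\n",
    "import random\n",
    "\n",
    "def to_uint8(value):\n",
    "    # Mirrors (x * 255).astype('uint8') for x in [0, 1): scale, then truncate.\n",
    "    return int(value * 255)\n",
    "\n",
    "def from_uint8(byte):\n",
    "    # Maps an 8-bit pixel back to the [0, 1] range.\n",
    "    return byte / 255\n",
    "\n",
    "random.seed(0)\n",
    "baseline = [random.random() for _ in range(1000)]\n",
    "max_error = max(abs(v - from_uint8(to_uint8(v))) for v in baseline)\n",
    "print('max round-trip error:', max_error, 'bound:', 1 / 255)\n",
    "```"
   ]
  },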
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "4CuercuM3bB2"
   },
   "outputs": [],
   "source": [
    "# Save the image to a variable in the format our model is expecting\n",
    "sanity_check_img = {'conv2d_input': [{'b64': new_image_string}]}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "o16qhm793m2j"
   },
   "outputs": [],
   "source": [
    "# Make the prediction request (the API expects a list of instances)\n",
    "sanity_check_resp = predict_json(PROJECT_ID, MODEL, [sanity_check_img], IG_VERSION)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "rL_LyIrr3vU5"
   },
   "outputs": [],
   "source": [
    "# View explanations on the baseline random image\n",
    "# (new variable names, so the request in sanity_check_img isn't overwritten)\n",
    "sanity_attr_b64 = sanity_check_resp['explanations'][0]['attributions_by_label'][0]['attributions']['data']['b64_jpeg']\n",
    "sanity_attr_img = mpimg.imread(io.BytesIO(base64.b64decode(sanity_attr_b64)), format='JPG')\n",
    "\n",
    "plt.imshow(sanity_attr_img, interpolation='nearest')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "2jJYsrkRCsMB"
   },
   "source": [
    "Because this request sends the baseline image itself, the model's score for it should be nearly identical to the baseline score, so the difference between the two should be close to 0. Run the following cell to confirm. If there is a noticeable difference between these two values, you may need to increase the number of integral steps used when you deploy your model."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "WpgZEInNCz4G"
   },
   "outputs": [],
   "source": [
    "baseline_score = sanity_check_resp['explanations'][0]['attributions_by_label'][0]['baseline_score']\n",
    "example_score = sanity_check_resp['explanations'][0]['attributions_by_label'][0]['example_score']\n",
    "\n",
    "print(abs(baseline_score - example_score))"
   ]
  },
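  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "Why integral steps matter: Integrated Gradients approximates a path integral with a finite sum, and with exact integration the attributions add up to `example_score - baseline_score` (the completeness property described in the Integrated Gradients paper linked below). The toy sketch below applies IG to a hand-written logistic model, not the deployed Keras model, using a midpoint Riemann sum. The weights and input are made up, and the midpoint rule is one choice of approximation, not necessarily the one the service uses. It shows the gap between the attribution sum and the score difference shrinking as `num_steps` grows, which is the role `--num-integral-steps` plays at deployment:\n",
    "\n",
    "```python\n",
    "import math\n",
    "\n",
    "def sigmoid(z):\n",
    "    return 1.0 / (1.0 + math.exp(-z))\n",
    "\n",
    "def model(x, w):\n",
    "    # Toy stand-in for a class score: sigmoid of a weighted sum.\n",
    "    return sigmoid(sum(wi * xi for wi, xi in zip(w, x)))\n",
    "\n",
    "def gradient(x, w):\n",
    "    # d/dx_i sigmoid(w . x) = w_i * s * (1 - s)\n",
    "    s = model(x, w)\n",
    "    return [wi * s * (1.0 - s) for wi in w]\n",
    "\n",
    "def integrated_gradients(x, baseline, w, num_steps):\n",
    "    # Midpoint Riemann sum over the straight path from baseline to x.\n",
    "    totals = [0.0] * len(x)\n",
    "    for k in range(num_steps):\n",
    "        alpha = (k + 0.5) / num_steps\n",
    "        point = [b + alpha * (xi - b) for xi, b in zip(x, baseline)]\n",
    "        for i, g in enumerate(gradient(point, w)):\n",
    "            totals[i] += g\n",
    "    return [(xi - b) * t / num_steps for xi, b, t in zip(x, baseline, totals)]\n",
    "\n",
    "w = [2.0, -1.0, 0.5]\n",
    "x = [1.0, 0.5, 2.0]\n",
    "baseline = [0.0, 0.0, 0.0]\n",
    "score_gap = model(x, w) - model(baseline, w)\n",
    "for steps in (1, 5, 50):\n",
    "    attributions = integrated_gradients(x, baseline, w, steps)\n",
    "    print(steps, 'steps, completeness error:', abs(sum(attributions) - score_gap))\n",
    "```"
   ]
  },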
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "deletable": true,
    "editable": true,
    "id": "x27DXeUGzb-M"
   },
   "source": [
    "## Cleaning up\n",
    "\n",
    "To clean up all GCP resources used in this project, you can [delete the GCP\n",
    "project](https://cloud.google.com/resource-manager/docs/creating-managing-projects#shutting_down_projects) you used for the tutorial.\n",
    "\n",
    "Alternatively, you can clean up individual resources by running the following\n",
    "commands:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "deletable": true,
    "editable": true,
    "id": "no210oWF68Uk"
   },
   "outputs": [],
   "source": [
    "# Delete model version resources\n",
    "!gcloud ai-platform versions delete $IG_VERSION --quiet --model $MODEL --region=$REGION\n",
    "!gcloud ai-platform versions delete $XRAI_VERSION --quiet --model $MODEL --region=$REGION\n",
    "\n",
    "# Delete model resource\n",
    "!gcloud ai-platform models delete $MODEL --quiet --region=$REGION\n",
    "\n",
    "# Delete Cloud Storage objects that were created\n",
    "!gsutil -m rm -r $BUCKET_NAME"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "deletable": true,
    "editable": true,
    "id": "3F2g4OjbJ3gZ"
   },
   "source": [
    "If your Cloud Storage bucket doesn't contain any other objects and you would like to delete it, run `gsutil rm -r gs://$BUCKET_NAME`."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "deletable": true,
    "editable": true,
    "id": "K0UXLWaBJnrY"
   },
   "source": [
    "## What's next?\n",
    "\n",
    "To learn more about AI Explanations, check out these resources:\n",
    "\n",
    "* [AI Explanations documentation](https://cloud.google.com/ai-platform/prediction/docs/ai-explanations/overview)\n",
    "* [AI Explanations whitepaper](https://storage.googleapis.com/cloud-ai-whitepapers/AI%20Explainability%20Whitepaper.pdf)\n",
    "* [Integrated gradients paper](https://arxiv.org/abs/1703.01365)\n",
    "* [XRAI paper](https://arxiv.org/abs/1906.02825)"
   ]
  }
 ],
 "metadata": {
  "accelerator": "GPU",
  "colab": {
   "collapsed_sections": [],
   "name": "ai-explanations-image.ipynb",
   "provenance": [],
   "toc_visible": true
  },
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
